{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
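    "# Compound Representation Learning and Property Prediction\n",
    "\n",
    "In this tutorial, we show how to use graph neural network (GNN) models to predict the properties of compounds. Specifically, we demonstrate how to pretrain a model, how to finetune it for a downstream task, and how to run inference with the final model. For more details, see the explanations in \"[info graph](https://github.com/PaddlePaddle/PaddleHelix/apps/pretrained_compound/info_graph/README_cn.md)\" and \"[pretrained GNN](https://github.com/PaddlePaddle/PaddleHelix/apps/pretrained_compound/pretrain_gnns/README_cn.md)\".\n",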
    "\n",
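    "# Part I: Pretraining\n",
    "\n",
    "In this part, we show how to pretrain a compound GNN model. The pretraining techniques used here build on the pretrained GNN work and include attribute masking, context prediction, and supervised pretraining.\n",
    "For more details, see the files `pretrain_attrmask.py` and `pretrain_supervised.py`."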
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import sys\n",
    "sys.path.insert(0, os.getcwd() + \"/..\")\n",
    "os.chdir(\"../apps/pretrained_compound/pretrain_gnns\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "PaddleHelix is a deep learning framework for biocomputing built on top of PaddlePaddle."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2021-05-08 15:31:52,712 - INFO - ujson not install, fail back to use json instead\n",
      "2021-05-08 15:31:52,801 - INFO - Enabling RDKit 2020.09.1 jupyter extensions\n"
     ]
    }
   ],
   "source": [
    "import paddle\n",
    "import paddle.nn as nn\n",
    "import paddle.distributed as dist\n",
    "import pgl\n",
    "\n",
    "from pahelix.model_zoo.pretrain_gnns_model import PretrainGNNModel, AttrmaskModel\n",
    "from pahelix.datasets.zinc_dataset import load_zinc_dataset\n",
    "from pahelix.utils.splitters import RandomSplitter\n",
    "from pahelix.featurizers.pretrain_gnn_featurizer import AttrmaskTransformFn, AttrmaskCollateFn\n",
    "from pahelix.utils import load_json_config"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
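    "## Load the configuration\n",
    "\n",
    "Here we use `compound_encoder_config` and `model_config` to hold the model configurations. `PretrainGNNModel` is the basic GNN model used for pretraining GNNs. `AttrmaskModel` is an unsupervised pretraining model that randomly masks the atom type of some nodes and then tries to predict the masked atom types. We use the Adam optimizer with the learning rate set to 0.001.\n"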
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/bio/.local/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
      "  and should_run_async(code)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[PretrainGNNModel] embed_dim:300\n",
      "[PretrainGNNModel] dropout_rate:0.5\n",
      "[PretrainGNNModel] norm_type:batch_norm\n",
      "[PretrainGNNModel] graph_norm:False\n",
      "[PretrainGNNModel] residual:False\n",
      "[PretrainGNNModel] layer_num:5\n",
      "[PretrainGNNModel] gnn_type:gin\n",
      "[PretrainGNNModel] JK:last\n",
      "[PretrainGNNModel] readout:mean\n",
      "[PretrainGNNModel] atom_names:['atomic_num', 'chiral_tag']\n",
      "[PretrainGNNModel] bond_names:['bond_dir', 'bond_type']\n"
     ]
    }
   ],
   "source": [
    "compound_encoder_config = load_json_config(\"model_configs/pregnn_paper.json\")\n",
    "model_config = load_json_config(\"model_configs/pre_Attrmask.json\")\n",
    "\n",
    "compound_encoder = PretrainGNNModel(compound_encoder_config)\n",
    "model = AttrmaskModel(model_config, compound_encoder)\n",
    "opt = paddle.optimizer.Adam(0.001, parameters=model.parameters())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
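    "## Dataset loading and feature extraction\n",
    "### Download the dataset with `wget`\n",
    "We first use `wget` to download a small test dataset. If `wget` is not available on your machine, you can copy the link below into your browser to download the data instead. Note that you then need to move the archive to this path: \"../apps/pretrained_compound/pretrain_gnns/\"."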
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2021-05-08 15:32:00--  https://baidu-nlp.bj.bcebos.com/PaddleHelix%2Fdatasets%2Fcompound_datasets%2Fchem_dataset_small.tgz\n",
      "Resolving baidu-nlp.bj.bcebos.com (baidu-nlp.bj.bcebos.com)... 10.70.0.165\n",
      "Connecting to baidu-nlp.bj.bcebos.com (baidu-nlp.bj.bcebos.com)|10.70.0.165|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 609563 (595K) [application/gzip]\n",
      "Saving to: ‘PaddleHelix%2Fdatasets%2Fcompound_datasets%2Fchem_dataset_small.tgz.2’\n",
      "\n",
      "100%[======================================>] 609,563     --.-K/s   in 0.04s   \n",
      "\n",
      "2021-05-08 15:32:01 (14.0 MB/s) - ‘PaddleHelix%2Fdatasets%2Fcompound_datasets%2Fchem_dataset_small.tgz.2’ saved [609563/609563]\n",
      "\n",
      "tox21  zinc_standard_agent\n"
     ]
    }
   ],
   "source": [
    "### Download a toy dataset for demonstration:\n",
    "!wget \"https://baidu-nlp.bj.bcebos.com/PaddleHelix%2Fdatasets%2Fcompound_datasets%2Fchem_dataset_small.tgz\" --no-check-certificate\n",
    "!tar -zxf \"PaddleHelix%2Fdatasets%2Fcompound_datasets%2Fchem_dataset_small.tgz\"\n",
    "!ls \"./chem_dataset_small\"\n",
    "### Optionally download the full dataset instead:\n",
    "# !wget \"http://snap.stanford.edu/gnn-pretrain/data/chem_dataset.zip\" --no-check-certificate\n",
    "# !unzip \"chem_dataset.zip\"\n",
    "# !ls \"./chem_dataset\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
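    "### Load the dataset and generate features\n",
    "Here we use the Zinc dataset for pretraining. We use the small toy dataset for demonstration; you can load the full dataset instead.\n",
    "`AttrmaskTransformFn` is used together with `AttrmaskModel`. It generates features: raw inputs are converted into features the network can consume, e.g. SMILES strings are turned into node and edge features."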
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dataset num: 1000\n"
     ]
    }
   ],
   "source": [
    "### Load only the first 1000 samples of the toy dataset to speed things up\n",
    "dataset = load_zinc_dataset(\"./chem_dataset_small/zinc_standard_agent/\")\n",
    "dataset = dataset[:1000]\n",
    "print(\"dataset num: %s\" % (len(dataset)))\n",
    "\n",
    "transform_fn = AttrmaskTransformFn()\n",
    "dataset.transform(transform_fn, num_workers=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
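    "## Start training\n",
    "\n",
    "Now we train the Attrmask model. We train for only two epochs as a demonstration. Here `AttrmaskCollateFn` aggregates multiple samples into a mini-batch, and data loading is sped up with 4 `workers`. We then save the pretrained compound encoder under \"./model/pretrain_attrmask\" as the initial model for downstream tasks."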
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/bio/tools/paddle2.0/lib/python3.7/site-packages/paddle/nn/layer/norm.py:648: UserWarning: When training, we now always track global mean and variance.\n",
      "  \"When training, we now always track global mean and variance.\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:0 train/loss:2.835863\n",
      "epoch:1 train/loss:0.9871801\n"
     ]
    }
   ],
   "source": [
    "def train(model, dataset, collate_fn, opt):\n",
    "    data_gen = dataset.get_data_loader(\n",
    "            batch_size=128, \n",
    "            num_workers=4, \n",
    "            shuffle=True,\n",
    "            collate_fn=collate_fn)\n",
    "    list_loss = []\n",
    "    model.train()\n",
    "    for graphs, masked_node_indice, masked_node_label in data_gen:\n",
    "        graphs = graphs.tensor()\n",
    "        masked_node_indice = paddle.to_tensor(masked_node_indice, 'int64')\n",
    "        masked_node_label = paddle.to_tensor(masked_node_label, 'int64')\n",
    "        loss = model(graphs, masked_node_indice, masked_node_label)\n",
    "        loss.backward()\n",
    "        opt.step()\n",
    "        opt.clear_grad()\n",
    "        list_loss.append(loss.numpy())\n",
    "    return np.mean(list_loss)\n",
    "\n",
    "collate_fn = AttrmaskCollateFn(\n",
    "        atom_names=compound_encoder_config['atom_names'], \n",
    "        bond_names=compound_encoder_config['bond_names'],\n",
    "        mask_ratio=0.15)\n",
    "\n",
    "for epoch_id in range(2):\n",
    "    train_loss = train(model, dataset, collate_fn, opt)\n",
    "    print(\"epoch:%d train/loss:%s\" % (epoch_id, train_loss))\n",
    "paddle.save(compound_encoder.state_dict(), \n",
    "        './model/pretrain_attrmask/compound_encoder.pdparams')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This concludes the pretraining part. You can adjust the parameters above to suit your own needs."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Part II: Finetuning on a downstream task\n",
    "\n",
    "Next, we show how to finetune the pretrained model to adapt it to a downstream task.\n",
    "\n",
    "For more details, see the `finetune.py` file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pahelix.utils.splitters import \\\n",
    "    RandomSplitter, IndexSplitter, ScaffoldSplitter\n",
    "from pahelix.datasets import *\n",
    "\n",
    "from src.model import DownstreamModel\n",
    "from src.featurizer import DownstreamTransformFn, DownstreamCollateFn\n",
    "from src.utils import calc_rocauc_score, exempt_parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
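    "Downstream datasets are usually small and target different tasks. For example, the BBBP dataset is used to predict the blood-brain barrier permeability of compounds, and the Tox21 dataset is used to predict compound toxicity. Here we use the Tox21 dataset for demonstration."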
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase', 'NR-ER', 'NR-ER-LBD', 'NR-PPAR-gamma', 'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53']\n"
     ]
    }
   ],
   "source": [
    "task_names = get_default_tox21_task_names()\n",
    "print(task_names)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
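    "## Load the configuration\n",
    "\n",
    "We use `compound_encoder_config` and `model_config` to load the model configurations. Note that the model architecture settings here must match those of the pretrained model; otherwise loading the pretrained weights will fail.\n",
    "`DownstreamModel` is a supervised GNN model for the prediction tasks defined in `task_names` above.\n",
    "\n",
    "We use `BCELoss` as the criterion and the Adam optimizer with the learning rate set to 0.001."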
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[PretrainGNNModel] embed_dim:300\n",
      "[PretrainGNNModel] dropout_rate:0.5\n",
      "[PretrainGNNModel] norm_type:batch_norm\n",
      "[PretrainGNNModel] graph_norm:False\n",
      "[PretrainGNNModel] residual:False\n",
      "[PretrainGNNModel] layer_num:5\n",
      "[PretrainGNNModel] gnn_type:gin\n",
      "[PretrainGNNModel] JK:last\n",
      "[PretrainGNNModel] readout:mean\n",
      "[PretrainGNNModel] atom_names:['atomic_num', 'chiral_tag']\n",
      "[PretrainGNNModel] bond_names:['bond_dir', 'bond_type']\n"
     ]
    }
   ],
   "source": [
    "compound_encoder_config = load_json_config(\"model_configs/pregnn_paper.json\")\n",
    "model_config = load_json_config(\"model_configs/down_linear.json\")\n",
    "model_config['num_tasks'] = len(task_names)\n",
    "\n",
    "compound_encoder = PretrainGNNModel(compound_encoder_config)\n",
    "model = DownstreamModel(model_config, compound_encoder)\n",
    "criterion = nn.BCELoss(reduction='none')\n",
    "opt = paddle.optimizer.Adam(0.001, parameters=model.parameters())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the pretrained model\n",
    "\n",
    "Load the model obtained in the pretraining stage. Here we load the \"pretrain_attrmask\" model as an example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "compound_encoder.set_state_dict(paddle.load('./model/pretrain_attrmask/compound_encoder.pdparams'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
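    "## Dataset loading and feature extraction\n",
    "\n",
    "`DownstreamTransformFn` is used together with `DownstreamModel`. It generates features: raw inputs are converted into features the network can consume, e.g. SMILES strings are turned into node and edge features.\n",
    "\n",
    "The Tox21 dataset serves as the downstream dataset, and we use `ScaffoldSplitter` to split it into training/validation/test sets. `ScaffoldSplitter` first sorts the compounds by their Bemis-Murcko scaffold and then, going from front to back, assigns the fraction given by `frac_train` to the training set, the fraction given by `frac_valid` to the validation set, and the rest to the test set. `ScaffoldSplitter` gives a better assessment of how well the model generalizes to out-of-distribution samples. Other splitters, such as `RandomSplitter`, `RandomScaffoldSplitter`, and `IndexSplitter`, can also be used here."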
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "RDKit WARNING: [15:32:26] WARNING: not removing hydrogen atom without neighbors\n",
      "RDKit WARNING: [15:32:36] WARNING: not removing hydrogen atom without neighbors\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train/Valid/Test num: 6264/783/784\n"
     ]
    }
   ],
   "source": [
    "### Load the toy dataset:\n",
    "dataset = load_tox21_dataset(\"./chem_dataset_small/tox21\", task_names)\n",
    "### Load the full dataset:\n",
    "# dataset = load_tox21_dataset(\"./chem_dataset/tox21\", task_names)\n",
    "dataset.transform(DownstreamTransformFn(), num_workers=4)\n",
    "\n",
    "# splitter = RandomSplitter()\n",
    "splitter = ScaffoldSplitter()\n",
    "train_dataset, valid_dataset, test_dataset = splitter.split(\n",
    "        dataset, frac_train=0.8, frac_valid=0.1, frac_test=0.1)\n",
    "print(\"Train/Valid/Test num: %s/%s/%s\" % (\n",
    "        len(train_dataset), len(valid_dataset), len(test_dataset)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
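    "## Start training\n",
    "\n",
    "For demonstration purposes, we finetune the model initialized from attrmask for only 4 epochs. Here `DownstreamCollateFn` aggregates multiple samples into a mini-batch. Since each downstream task contains multiple sub-tasks, we compute the ROC-AUC of each sub-task separately and then take their mean as the final evaluation metric."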
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Valid ratio: 0.7603235\n",
      "Task evaluated: 12/12\n",
      "Valid ratio: 0.7513818\n",
      "Task evaluated: 12/12\n",
      "epoch:0 train/loss:0.26803324\n",
      "epoch:0 val/auc:0.6266238165767454\n",
      "epoch:0 test/auc:0.6216414776039183\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/bio/tools/paddle2.0/lib/python3.7/site-packages/paddle/nn/layer/norm.py:648: UserWarning: When training, we now always track global mean and variance.\n",
      "  \"When training, we now always track global mean and variance.\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Valid ratio: 0.7603235\n",
      "Task evaluated: 12/12\n",
      "Valid ratio: 0.7513818\n",
      "Task evaluated: 12/12\n",
      "epoch:1 train/loss:0.22468299\n",
      "epoch:1 val/auc:0.6659571793610889\n",
      "epoch:1 test/auc:0.6539646927043956\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/bio/tools/paddle2.0/lib/python3.7/site-packages/paddle/nn/layer/norm.py:648: UserWarning: When training, we now always track global mean and variance.\n",
      "  \"When training, we now always track global mean and variance.\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Valid ratio: 0.7603235\n",
      "Task evaluated: 12/12\n",
      "Valid ratio: 0.7513818\n",
      "Task evaluated: 12/12\n",
      "epoch:2 train/loss:0.21930875\n",
      "epoch:2 val/auc:0.6822124992478676\n",
      "epoch:2 test/auc:0.6736089418895174\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/bio/tools/paddle2.0/lib/python3.7/site-packages/paddle/nn/layer/norm.py:648: UserWarning: When training, we now always track global mean and variance.\n",
      "  \"When training, we now always track global mean and variance.\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Valid ratio: 0.7603235\n",
      "Task evaluated: 12/12\n",
      "Valid ratio: 0.7513818\n",
      "Task evaluated: 12/12\n",
      "epoch:3 train/loss:0.21560585\n",
      "epoch:3 val/auc:0.6715751869311037\n",
      "epoch:3 test/auc:0.6290005037852295\n"
     ]
    }
   ],
   "source": [
    "def train(model, train_dataset, collate_fn, criterion, opt):\n",
    "    data_gen = train_dataset.get_data_loader(\n",
    "            batch_size=128, \n",
    "            num_workers=4, \n",
    "            shuffle=True,\n",
    "            collate_fn=collate_fn)\n",
    "    list_loss = []\n",
    "    model.train()\n",
    "    for graphs, valids, labels in data_gen:\n",
    "        graphs = graphs.tensor()\n",
    "        labels = paddle.to_tensor(labels, 'float32')\n",
    "        valids = paddle.to_tensor(valids, 'float32')\n",
    "        preds = model(graphs)\n",
    "        loss = criterion(preds, labels)\n",
    "        loss = paddle.sum(loss * valids) / paddle.sum(valids)\n",
    "        loss.backward()\n",
    "        opt.step()\n",
    "        opt.clear_grad()\n",
    "        list_loss.append(loss.numpy())\n",
    "    return np.mean(list_loss)\n",
    "\n",
    "def evaluate(model, test_dataset, collate_fn):\n",
    "    data_gen = test_dataset.get_data_loader(\n",
    "            batch_size=128, \n",
    "            num_workers=4, \n",
    "            shuffle=False,\n",
    "            collate_fn=collate_fn)\n",
    "    total_pred = []\n",
    "    total_label = []\n",
    "    total_valid = []\n",
    "    model.eval()\n",
    "    for graphs, valids, labels in data_gen:\n",
    "        graphs = graphs.tensor()\n",
    "        labels = paddle.to_tensor(labels, 'float32')\n",
    "        valids = paddle.to_tensor(valids, 'float32')\n",
    "        preds = model(graphs)\n",
    "        total_pred.append(preds.numpy())\n",
    "        total_valid.append(valids.numpy())\n",
    "        total_label.append(labels.numpy())\n",
    "    total_pred = np.concatenate(total_pred, 0)\n",
    "    total_label = np.concatenate(total_label, 0)\n",
    "    total_valid = np.concatenate(total_valid, 0)\n",
    "    return calc_rocauc_score(total_label, total_pred, total_valid)\n",
    "\n",
    "collate_fn = DownstreamCollateFn(\n",
    "        atom_names=compound_encoder_config['atom_names'], \n",
    "        bond_names=compound_encoder_config['bond_names'])\n",
    "for epoch_id in range(4):\n",
    "    train_loss = train(model, train_dataset, collate_fn, criterion, opt)\n",
    "    val_auc = evaluate(model, valid_dataset, collate_fn)\n",
    "    test_auc = evaluate(model, test_dataset, collate_fn)\n",
    "    print(\"epoch:%s train/loss:%s\" % (epoch_id, train_loss))\n",
    "    print(\"epoch:%s val/auc:%s\" % (epoch_id, val_auc))\n",
    "    print(\"epoch:%s test/auc:%s\" % (epoch_id, test_auc))\n",
    "paddle.save(model.state_dict(), './model/tox21/model.pdparams')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
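    "# Part III: Inference with the downstream model\n",
    "In this part, we briefly show how to use the trained downstream model to make predictions for a given SMILES string."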
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the configuration\n",
    "This step is essentially the same as in Part II."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[PretrainGNNModel] embed_dim:300\n",
      "[PretrainGNNModel] dropout_rate:0.5\n",
      "[PretrainGNNModel] norm_type:batch_norm\n",
      "[PretrainGNNModel] graph_norm:False\n",
      "[PretrainGNNModel] residual:False\n",
      "[PretrainGNNModel] layer_num:5\n",
      "[PretrainGNNModel] gnn_type:gin\n",
      "[PretrainGNNModel] JK:last\n",
      "[PretrainGNNModel] readout:mean\n",
      "[PretrainGNNModel] atom_names:['atomic_num', 'chiral_tag']\n",
      "[PretrainGNNModel] bond_names:['bond_dir', 'bond_type']\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/bio/.local/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
      "  and should_run_async(code)\n"
     ]
    }
   ],
   "source": [
    "compound_encoder_config = load_json_config(\"model_configs/pregnn_paper.json\")\n",
    "model_config = load_json_config(\"model_configs/down_linear.json\")\n",
    "model_config['num_tasks'] = len(task_names)\n",
    "\n",
    "compound_encoder = PretrainGNNModel(compound_encoder_config)\n",
    "model = DownstreamModel(model_config, compound_encoder)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the trained downstream model\n",
    "Load the downstream model trained in Part II."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.set_state_dict(paddle.load('./model/tox21/model.pdparams'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
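    "## Run inference\n",
    "Make predictions for a given SMILES string. We call `DownstreamTransformFn` and `DownstreamCollateFn` directly to convert the raw SMILES string into the model input.\n",
    "\n",
    "Taking the Tox21 dataset as an example, the downstream model outputs predictions for the 12 sub-tasks in Tox21."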
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SMILES:O=C1c2ccccc2C(=O)C1c1ccc2cc(S(=O)(=O)[O-])cc(S(=O)(=O)[O-])c2n1\n",
      "Predictions:\n",
      "  NR-AR:\t0.32046446\n",
      "  NR-AR-LBD:\t0.2179917\n",
      "  NR-AhR:\t0.43747446\n",
      "  NR-Aromatase:\t0.3693441\n",
      "  NR-ER:\t0.38625807\n",
      "  NR-ER-LBD:\t0.26430473\n",
      "  NR-PPAR-gamma:\t0.31611276\n",
      "  SR-ARE:\t0.436297\n",
      "  SR-ATAD5:\t0.28199005\n",
      "  SR-HSE:\t0.24803819\n",
      "  SR-MMP:\t0.42311215\n",
      "  SR-p53:\t0.22394522\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/bio/tools/paddle2.0/lib/python3.7/site-packages/paddle/nn/layer/norm.py:648: UserWarning: When training, we now always track global mean and variance.\n",
      "  \"When training, we now always track global mean and variance.\")\n"
     ]
    }
   ],
   "source": [
    "SMILES=\"O=C1c2ccccc2C(=O)C1c1ccc2cc(S(=O)(=O)[O-])cc(S(=O)(=O)[O-])c2n1\"\n",
    "transform_fn = DownstreamTransformFn(is_inference=True)\n",
    "collate_fn = DownstreamCollateFn(\n",
    "        atom_names=compound_encoder_config['atom_names'], \n",
    "        bond_names=compound_encoder_config['bond_names'],\n",
    "        is_inference=True)\n",
    "graph = collate_fn([transform_fn({'smiles': SMILES})])\n",
    "model.eval()  # inference mode: disable dropout and use batch-norm running statistics\n",
    "preds = model(graph.tensor()).numpy()[0]\n",
    "print('SMILES:%s' % SMILES)\n",
    "print('Predictions:')\n",
    "for name, prob in zip(task_names, preds):\n",
    "    print(\"  %s:\\t%s\" % (name, prob))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
